Note
Go to the end to download the full example code.
ReEig Layer Animation#
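This animation visualizes how the ReEig (Rectified Eigenvalue) layer introduces non-linearity while preserving the SPD property:
\[\mathrm{ReEig}(X) = U \max(\Lambda, \epsilon I) U^T,\]
where \(X = U \Lambda U^T\) is the eigendecomposition and the maximum is taken elementwise over the diagonal entries (eigenvalues) of \(\Lambda\).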
# sphinx_gallery_thumbnail_number = 1
Understanding ReEig#
ReEig applies a ReLU-like function to eigenvalues:
Eigenvalues at or above the threshold \(\epsilon\) are left unchanged
Eigenvalues below threshold are clamped to \(\epsilon\)
Because every output eigenvalue is at least \(\epsilon > 0\), the output is guaranteed to remain SPD. The kink at \(\epsilon\) introduces non-linearity, much as ReLU does in standard neural networks.
import sys
from pathlib import Path
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
# Make the sibling helper module ``spd_visualization_utils`` importable.
# ``__file__`` is not defined when this code is executed interactively, so in
# that case assume we are running from the repository root and point at the
# conventional examples directory instead.
try:
    _current_dir = Path(__file__).parent
except NameError:
    _current_dir = Path.cwd() / "examples" / "visualizations"

# Prepend only once, so re-running the script does not grow sys.path.
if str(_current_dir) not in sys.path:
    sys.path.insert(0, str(_current_dir))

from spd_visualization_utils import create_ellipse_patch, setup_spd_plot
Setup and Data Generation#
# Seed kept for reproducibility; the example data below is itself deterministic.
np.random.seed(42)

# Threshold for ReEig: every output eigenvalue will be at least ``epsilon``.
epsilon = 0.3

# Eigenvalue pairs chosen so that some fall below the threshold.
eigval_sets = [
    np.array([0.1, 2.0]),  # First below threshold
    np.array([0.5, 1.5]),  # Both above
    np.array([0.05, 0.8]),  # First below
    np.array([1.0, 0.15]),  # Second below
    np.array([0.2, 0.25]),  # Both below threshold
]
# Derive the count from the data: the per-matrix lists below are built with
# ``zip``, which would silently truncate if this number were out of sync.
n_matrices = len(eigval_sets)

# Deterministic, evenly spaced rotations (eigenvector bases), one per matrix.
rotation_angles = np.linspace(0, np.pi, n_matrices)
rotation_matrices = [
    np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
    for a in rotation_angles
]

# Input matrices: X = U @ diag(eigvals) @ U.T (SPD because all eigvals > 0).
input_matrices = [
    U @ np.diag(eigvals) @ U.T
    for eigvals, U in zip(eigval_sets, rotation_matrices)
]

# ReEig: keep the eigenvectors U, clamp the eigenvalues from below at epsilon.
output_matrices = [
    U @ np.diag(np.maximum(eigvals, epsilon)) @ U.T
    for eigvals, U in zip(eigval_sets, rotation_matrices)
]
# One colour per matrix, sampled from the central part of the colormap.
colors = plt.cm.coolwarm(np.linspace(0.2, 0.8, n_matrices))

# Report how ReEig changes each eigenvalue pair.
print(f"Threshold epsilon = {epsilon}")
for i, eigvals in enumerate(eigval_sets):
    clamped = np.maximum(eigvals, epsilon)
    print(f"Matrix {i+1}: {eigvals} -> {clamped}")
Threshold epsilon = 0.3
Matrix 1: [0.1 2. ] -> [0.3 2. ]
Matrix 2: [0.5 1.5] -> [0.5 1.5]
Matrix 3: [0.05 0.8 ] -> [0.3 0.8]
Matrix 4: [1. 0.15] -> [1. 0.3]
Matrix 5: [0.2 0.25] -> [0.3 0.3]
Static Visualization#
First, let’s visualize the ReEig eigenvalue rectification function.
fig_static, axes_static = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: the scalar map lambda -> max(lambda, epsilon) that ReEig applies
# to every eigenvalue, drawn against the identity for comparison.
ax1 = axes_static[0]
x = np.linspace(0, 2.5, 200)
y_reeig = np.maximum(x, epsilon)
ax1.plot(x, x, "k--", alpha=0.4, label="Identity (y=x)", linewidth=2)
ax1.plot(x, y_reeig, "b-", linewidth=3, label=f"ReEig (epsilon={epsilon})")

# Inputs in [0, epsilon) are all mapped to epsilon: shade that square and draw
# the threshold as dashed guide lines.
ax1.fill_between(
    [0, epsilon], [epsilon, epsilon], [0, 0],
    color="red", alpha=0.15, label="Clamped region",
)
ax1.axhline(y=epsilon, color="red", linestyle="--", alpha=0.5)
ax1.axvline(x=epsilon, color="red", linestyle="--", alpha=0.5)

# Place each example eigenvalue at its rectified height. Clamped ones also get a
# dotted segment from their original value up to epsilon.
for color, eigvals in zip(colors, eigval_sets):
    for ev in eigvals:
        ax1.scatter(
            [ev], [max(ev, epsilon)],
            s=100, c=[color], edgecolors="black", linewidth=1.5, zorder=5,
        )
        if ev < epsilon:
            ax1.plot(
                [ev, ev], [ev, epsilon],
                color=color, linestyle=":", linewidth=1.5,
            )

ax1.set_xlim(-0.1, 2.5)
ax1.set_ylim(-0.1, 2.5)
ax1.set_xlabel("Input eigenvalue lambda", fontsize=12)
ax1.set_ylabel("Output eigenvalue max(lambda, epsilon)", fontsize=12)
ax1.set_title(
    "ReEig: Eigenvalue Rectification Function", fontsize=13, fontweight="bold"
)
ax1.legend(loc="lower right", fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.set_aspect("equal")
# Right panel: each input ellipse (left column) next to its ReEig output
# (right column), joined by an arrow.
ax2 = axes_static[1]
setup_spd_plot(
    ax2, xlim=(-4, 4), ylim=(-3.5, 3.5), title="ReEig Effect on SPD Matrices"
)

# One row per matrix, from top to bottom. ``centers`` is reused by the animation.
y_positions = np.linspace(2, -2, n_matrices)
centers = [(0, y) for y in y_positions]

x_offset = 1.5  # horizontal distance of each ellipse from its row centre
for i, (inp, out, center) in enumerate(zip(input_matrices, output_matrices, centers)):
    row_y = center[1]
    left = (center[0] - x_offset, row_y)
    right = (center[0] + x_offset, row_y)

    ax2.add_patch(
        create_ellipse_patch(
            inp, left, alpha=0.5, color=colors[i], edgecolor="black", linewidth=2
        )
    )
    # Arrow between the two ellipses, shortened by 0.8 at each end.
    ax2.annotate(
        "",
        xy=(right[0] - 0.8, row_y),
        xytext=(left[0] + 0.8, row_y),
        arrowprops=dict(arrowstyle="->", color="gray", lw=1.5),
    )
    ax2.add_patch(
        create_ellipse_patch(
            out, right, alpha=0.5, color=colors[i], edgecolor="black", linewidth=2
        )
    )

# Column headers.
for header, x_pos in (("Input", -x_offset), ("ReEig(Input)", x_offset)):
    ax2.text(x_pos, 3.2, header, ha="center", fontsize=11, fontweight="bold")
plt.tight_layout()

Mathematical Explanation#
ReEig introduces non-linearity while preserving SPD structure:
Eigendecomposition: \(X = U \Lambda U^T\)
Rectification: \(\Lambda' = \max(\Lambda, \epsilon I)\)
Reconstruction: \(Y = U \Lambda' U^T\)
Key properties:
Preserves eigenvectors: Only eigenvalues change
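Norm non-decreasing: \(\|Y\|_F \geq \|X\|_F\). Since \(\|X\|_F^2 = \sum_i \lambda_i^2\) and every eigenvalue is only ever raised (never lowered), the inequality holds. It is strict whenever some eigenvalue of \(X\) lies below \(\epsilon\).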
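Gradient flow: when ReEig is used inside a network, gradients are back-propagated through the eigendecomposition. This example only visualizes the forward pass.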
Animation#
The animation shows the ReEig eigenvalue rectification in action.
# Animation figure with three panels, left to right: eigenvalue map,
# input ellipses, and (interpolated) output ellipses.
fig_anim = plt.figure(figsize=(16, 6))
ax_eigen, ax_input, ax_output = (fig_anim.add_subplot(1, 3, k) for k in (1, 2, 3))

# Timeline: hold on the inputs, transition, then hold on the outputs.
n_frames = 50  # frames spent interpolating from input to output
pause_frames = 20  # frames held still before (and after) the transition
def draw_eigenvalue_plot(ax, t, highlight_idx=None):
    """Redraw the eigenvalue-rectification panel for animation progress ``t``.

    Clears ``ax`` and draws the identity line, the ReEig curve
    ``max(lambda, epsilon)``, the threshold guide lines and clamped region, and
    one marker per eigenvalue of every matrix in the module-level
    ``eigval_sets``. Each marker sits at ``x = lambda`` and moves vertically by
    linear interpolation from ``lambda`` (``t = 0``) to ``max(lambda, epsilon)``
    (``t = 1``).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to clear and redraw.
    t : float
        Interpolation progress: 0 shows the input eigenvalues, 1 the rectified ones.
    highlight_idx : int or None, optional
        Index into ``eigval_sets`` whose markers are drawn larger
        (size 150 instead of 80). ``None`` highlights nothing.
    """
    ax.clear()
    ax.set_xlim(-0.5, 3)
    ax.set_ylim(-0.2, 2.5)
    ax.set_xlabel("Input eigenvalue lambda", fontsize=11)
    ax.set_ylabel("Output eigenvalue", fontsize=11)
    ax.set_title("ReEig: Eigenvalue Rectification", fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.3)

    # Threshold lines and the region of inputs that get lifted to epsilon.
    ax.axhline(y=epsilon, color="red", linestyle="--", linewidth=2, alpha=0.7)
    ax.axvline(x=epsilon, color="red", linestyle="--", linewidth=2, alpha=0.7)
    ax.fill_between([0, epsilon], [epsilon, epsilon], [0, 0], color="red", alpha=0.1)

    # Identity (no change) versus the ReEig map.
    x_line = np.linspace(0, 3, 100)
    ax.plot(x_line, x_line, "k--", alpha=0.3, label="Identity")
    ax.plot(
        x_line,
        np.maximum(x_line, epsilon),
        "b-",
        linewidth=3,
        label=f"ReEig(epsilon={epsilon})",
    )

    for i, eigvals in enumerate(eigval_sets):
        # Size depends only on the matrix, so compute it once per matrix.
        marker_size = 150 if i == highlight_idx else 80
        for ev in eigvals:
            output_ev = max(ev, epsilon)
            current_y = ev + t * (output_ev - ev)
            ax.scatter(
                [ev],
                [current_y],
                s=marker_size,
                c=[colors[i]],
                edgecolors="black",
                linewidth=1.5,
                zorder=5,
            )
            # Only clamped eigenvalues move; show their displacement so far.
            if ev < epsilon and t > 0:
                ax.annotate(
                    "",
                    xy=(ev, current_y),
                    xytext=(ev, ev),
                    arrowprops=dict(arrowstyle="->", color=colors[i], lw=1.5),
                )

    ax.text(epsilon + 0.05, 0.05, f"epsilon={epsilon}", fontsize=10, color="red")
    ax.legend(loc="lower right", fontsize=9)
def _frame_progress(frame):
    """Map a frame index to ``(t, phase)``: eased progress in [0, 1] and a caption."""
    if frame < pause_frames:
        return 0.0, "Input eigenvalues"
    if frame < pause_frames + n_frames:
        # Cosine ease-in/ease-out of the linear progress through the transition.
        linear = (frame - pause_frames) / n_frames
        t = 0.5 * (1 - np.cos(np.pi * linear))
        return t, f"Applying ReEig (t={t:.2f})"
    return 1.0, "Rectified eigenvalues"


def animate(frame):
    """Redraw all three panels for animation frame ``frame``.

    The first ``pause_frames`` frames show the inputs (t = 0). The next
    ``n_frames`` frames ease t from 0 towards 1. The remaining frames show the
    rectified outputs (t = 1). Every panel is cleared and redrawn, and the
    figure suptitle is updated with the current phase.

    Returns an empty list: the animation is built with ``blit=False``, so no
    artists need to be reported back to ``FuncAnimation``.
    """
    t, phase = _frame_progress(frame)

    # Left panel: eigenvalue markers moving onto the ReEig curve.
    draw_eigenvalue_plot(ax_eigen, t)

    # Middle panel: input ellipses; a red "x" marks matrices with an
    # eigenvalue below epsilon.
    ax_input.clear()
    setup_spd_plot(ax_input, xlim=(-3, 3), ylim=(-3.5, 3.5), title="Input SPD Matrices")
    for i, (inp, center) in enumerate(zip(input_matrices, centers)):
        ax_input.add_patch(
            create_ellipse_patch(
                inp, center, alpha=0.6, color=colors[i], edgecolor="black", linewidth=2
            )
        )
        eigvals = eigval_sets[i]
        label = f"lambda=[{eigvals[0]:.2f}, {eigvals[1]:.2f}]"
        ax_input.text(
            center[0] + 1.8, center[1], label, fontsize=9, va="center", color=colors[i]
        )
        if np.any(eigvals < epsilon):
            ax_input.plot(center[0], center[1], "rx", markersize=12, mew=2)

    # Right panel: outputs interpolated between input and ReEig(input).
    ax_output.clear()
    setup_spd_plot(
        ax_output,
        xlim=(-3, 3),
        ylim=(-3.5, 3.5),
        title=r"Output: $U \max(\Lambda, \epsilon) U^T$",
    )
    for i, (inp, out, center) in enumerate(
        zip(input_matrices, output_matrices, centers)
    ):
        # Input and output share eigenvectors, so this linear blend has
        # eigenvalues (1 - t) * lambda + t * max(lambda, epsilon).
        interp = (1 - t) * inp + t * out
        ax_output.add_patch(
            create_ellipse_patch(
                interp, center, alpha=0.6, color=colors[i], edgecolor="black", linewidth=2
            )
        )
        inp_eigvals = eigval_sets[i]
        out_eigvals = np.maximum(inp_eigvals, epsilon)
        current_eigvals = (1 - t) * inp_eigvals + t * out_eigvals
        label = f"lambda=[{current_eigvals[0]:.2f}, {current_eigvals[1]:.2f}]"
        ax_output.text(
            center[0] + 1.8, center[1], label, fontsize=9, va="center", color=colors[i]
        )
        if np.any(inp_eigvals < epsilon):
            ax_output.plot(center[0], center[1], "g*", markersize=12, mew=2)

    # Legend proxies. ``linestyle="None"`` hides the proxy line. "x" is an
    # unfilled marker whose face colour is ignored and whose edge colour follows
    # ``color``, so the colour must be given via ``color`` for it to be visible.
    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="x",
            linestyle="None",
            color="red",
            markersize=10,
            markeredgewidth=2,
            label="Eigenvalue < epsilon",
        ),
        Line2D(
            [0],
            [0],
            marker="*",
            linestyle="None",
            color="green",
            markersize=12,
            label="Rectified",
        ),
    ]
    ax_output.legend(handles=legend_elements, loc="upper right", fontsize=9)

    fig_anim.suptitle(f"ReEig Layer — {phase}", fontsize=14, fontweight="bold")
    return []
def init():
    """Set-up hook passed to ``FuncAnimation`` as ``init_func``.

    Nothing has to be drawn up front, because every frame clears and redraws
    all panels itself. The hook therefore reports an empty list of artists.
    """
    no_artists = []
    return no_artists
# Create the animation. Keep the object referenced by ``anim``: if the
# FuncAnimation is garbage-collected, its timer stops and nothing animates.
total_frames = 2 * pause_frames + n_frames  # hold + transition + hold
anim = animation.FuncAnimation(
    fig_anim, animate, init_func=init, frames=total_frames, interval=50, blit=False
)
# The suptitle is only added once frames are drawn, so tight_layout cannot
# account for it here. Reserve the top strip of the figure explicitly so the
# suptitle does not overlap the subplot titles.
fig_anim.tight_layout(rect=(0, 0, 1, 0.94))
plt.show()